%--------------------------------------------------------------------------
%                  SIMULATION MAIN IMPUTATION PROCESS 
%-------------------------------------------------------------------------- 
%
% Title: Semi-parametric Selection Models for Potentially Non-ignorable Attrition...
% in Panel Studies with Refreshment Samples
% Manuscript ID PA-2013-094
% Y.SI, J.REITER AND S.HILLYGUS
%-------------------------------------------------------------------------- 
%               Bayesian semi-parametric selection models
%          joint model for (X, Y); W depends on (X, Y)
%                        Simulation Study
%-------------------------------------------------------------------------- 
%%%% Author: Yajuan Si
%%%% Latest edit date: 03/31/2014
clear;clc;close all;
tic;   % start the wall-clock timer for the whole run
%%%%-------------------global parameters----------------------%%%%
% Defined before the data are read so that the vector of true values can be
% built from p1/p2 rather than from hard-coded column numbers.
Re=100;  % number of repeated samples (replicates)
Nr=200; % size of refreshment sample
Np=800;   % size of panel sample
N=Np+Nr;   % number of all observations
Trun=3000; % number of MCMC iterations per replicate
burn =0; thin = 1;
effsize = floor((Trun-burn)/thin); % number of saved draws; floor keeps it a valid array size
H = 20; % truncation for psb
p1 = 5;  % dimension of Y1
p2 = 5; % dim of Y2
q =0; % dim of X
p = p1 + p2 + q; %total dim
d = repmat(2,[1 p]); %levels for X Y1&Y2
%%%%----------------read simulated data-----------------------------%%%%
pop=csvread('data/pop-simulation.csv');
% Fail fast: the estimates produced per replicate have p + p1 + (p2-1)
% entries, so true_q must have the same length, which requires exactly p
% population columns.
assert(size(pop,2)==p, ...
    'pop-simulation.csv has %d columns but p = %d.', size(pop,2), p);

% True values of the quantities of interest, in the same order as the MI
% estimates assembled for each replicate:
%   1) marginal proportion of each column equal to 1            (p values)
%   2) joint proportion of column j and column p1+j equal to 1   (p1 values)
%   3) joint proportion of column p1+j2 and the last Y2 column   (p2-1 values)
true_q = [mean(pop==1), zeros(1,p1), zeros(1,p2-1)];
for j = 1:p1
    true_q(p+j) = mean(pop(:,j)==1 & pop(:,p1+j)==1);
end
for j2 = 1:(p2-1)
    true_q(p+p1+j2) = mean(pop(:,p1+j2)==1 & pop(:,p1+p2)==1);
end

sampletrue=csvread('data/sample-simulation.csv');
sampledata=csvread('data/misdata-simulation.csv');
samplew=csvread('data/samplew-simulation.csv');
r_p_index=csvread('data/panelindex-simulation.csv');
r_r_index=csvread('data/refindex-simulation.csv');
%%%%--------------output files for repeated study----------------------%%%%
% One row per replicate, one column per quantity of interest.
q_mi=zeros(Re,length(true_q));u_mi=zeros(Re,length(true_q));
b_mi=zeros(Re,length(true_q));t_mi=zeros(Re,length(true_q));
low_mi=zeros(Re,length(true_q));up_mi=zeros(Re,length(true_q));
cover=zeros(Re,length(true_q));
%%%%----------------------output files for each mcmc---------------%%%%%
alpha_out = zeros(effsize,1);
beta_w_out = zeros(effsize,p+1); cell_out=zeros(effsize,p);
pi_out = zeros(effsize,H);
kout = zeros(effsize,1);   %number of unique clusters
y1_out = zeros(effsize,p1);
y2_out = zeros(effsize,p2); w_out = zeros(effsize,1);

%%%%--------------output for multiple imputation--------------------%%%%
m = 20; %no of completed datasets;
cut = Trun - 50 * (0:1:m-1);   % MCMC iterations at which a completed data set is stored
% Every snapshot iteration must actually be visited by the sampler,
% otherwise fewer than m completed data sets would be stored.
assert(min(cut) >= 1, ...
    'Trun = %d is too small to store m = %d completed data sets 50 iterations apart.', Trun, m);
impdata_out = zeros(Re,N*m,p+1);   % column 1: imputation index, columns 2..p+1: data

q_prob=zeros(m,p);  u_prob=zeros(m,p);
q_prob_j=zeros(m,p1);   u_prob_j=zeros(m,p1);
q_prob_j2=zeros(m, p2-1); u_prob_j2=zeros(m,p2-1);
%%%%----------prior specification ------------------------------------%%%%
a0=1/4; b0=1/4; %for alpha
Sigma0 = 1; %prior on beta

a.j = ones(max(d),1);  %for phi in dirichlet/beta distribution
% NOTE(review): delta1, delta2 and lambda_tun are not referenced anywhere
% else in this script.
delta1=1; delta2 = 1; lambda_tun =0.1;  %probit+shrinkage

%%%%----------global parameters & initial setting----------------------%%%%
zupdateprob = zeros(N,H);   % unnormalised class-membership weights
cc_f = zeros(H,1);          % work vector for the stick-breaking update

%%
%%%%--------------------repeated sampling----------------------------%%%%
for re=1:Re
    imp=0;   % number of completed data sets stored so far for this replicate
    %%%%-----------------------read data------------------------%%%%
    % Rows belonging to replicate `re`. One is added to every value because
    % the values are used further down as category indices into phi.
    data_0=sampledata(sampledata(:,1)==re,2:(p+1))+1;
    w_0=samplew(re,:)';
    % NOTE(review): data_true is not referenced again in this script.
    data_true=sampletrue(sampletrue(:,1)==re,2:(p+1))+1;
    panelindex=r_p_index(re,:)'; refindex=r_r_index(re,:)';

    y1_0 = data_0(:,1:p1); y2_0 = data_0(:,p1+1:p);

    Ncp=sum(w_0(panelindex));   % size of complete panel subsample
    Nip=Np-Ncp;                  % size of incomplete panel subsample

    % Named panel_* rather than `all` so MATLAB's builtin all() is not shadowed.
    panel_all=[w_0(panelindex),y1_0(panelindex,:),y2_0(panelindex,:)];
    panel_sorted=sortrows(panel_all,-1);   % re-order data by W 1st-col in reverse order

    % Stacked layout used from here on:
    %   rows 1..Ncp     panel units with W = 1
    %   rows Ncp+1..Np  panel units with W = 0
    %   rows Np+1..N    refreshment units
    W_0=[panel_sorted(:,1)',w_0(refindex)']';
    Y1_0=[panel_sorted(:,2:p1+1); y1_0(refindex,:)];  % observe Y1 only for the first Np units
    Y2_0=[panel_sorted(:,p1+2:p+1); y2_0(refindex,:)]; % observe Y2 only for the Ncp units and the Nr units

    %%%%----------------------------initial data------------------------%%%%
    % Starting values for the entries treated as missing.
    W = [W_0(1:Np)' binornd(1, Ncp/Np, Nr, 1)']';   % missing values
    Y2=Y2_0; Y1=Y1_0;
    for j=1:p1
        % Y1 of the refreshment units: Bernoulli draw at the panel frequency of level 2
        Y1(:,j) = [Y1_0(1:Np,j); binornd(1, mean(Y1_0(1:Np,j)==2), [Nr, 1])+1];
    end
    for j=1:p2
        % Y2 of the incomplete panel units: Bernoulli draw at the population
        % frequency of the matching column
        Y2_j_p=mean(pop(:,p1+j)==1);
        Y2(:,j) = [Y2_0(1:Ncp,j); binornd(1, Y2_j_p, [Nip, 1])+1; Y2_0(Np+1:N,j)];
    end

    data = [Y1 Y2];
    %%%%-------------------initial parameter values------------------------%%%%
    rows = repmat(1:p,N,1); rows = rows(:);   % variable index of every entry of data(:)

    phi=zeros(H,p,max(d));      % cell probabilities
    for h=1:H
        for j=1:p
            for l=1:d(j)
                phi(h,j,l)=sum(data(:,j)==l)/N;   % every class starts at the empirical frequencies
            end
        end
    end

    mat=zeros(p,max(d));       % for updating phi

    alpha = 1;

    ph = randsample(1:H,N,true); %  %latent class indicator

    nus = betarnd(1,alpha,[1 H-1]);  % stick-breaking random variables
    nu =zeros(1,H);
    nu(1:H-1) = nus.*cumprod([1 1-nus(:,1:H-2)]); % category probabilities
    nu(H)=1-sum(nu(1:H-1));

    %%
    %%%%--------------------------------MCMC----------------------------%%%%
    for t=1:Trun

        % 1) -- update stick-breaking random variables -- %
        for h = 1:H-1
            % factor the current nus(h) out of nu, redraw it, then put it back
            cc_f(h) = nu(h)/nus(h); cc_f(h+1:H) = nu(h+1:H)/(1 - nus(h));
            nus(h) = betarnd(1 + sum(ph==h), alpha + sum(ph>h));
            if nus(h)>1-1e-5
                nus(h)=1-1e-5;   % keep 1-nus(h) strictly positive
            end
            nu(h) = cc_f(h)*nus(h); nu(h+1:H) = cc_f(h+1:H)*(1-nus(h));
        end

        % 2)-- update ph: allocation to atoms -- %
        cols = data(:);
        lin_idx = sub2ind([p,max(d)],rows,cols);
        for h = 1:H
            phih = reshape(phi(h,:,:),p,max(d));
            tmpmatL = reshape(phih(lin_idx),[N,p]);   % phi(h,j,data(i,j)) for every i,j
            zupdateprob(:,h) = nu(h) * prod(tmpmatL,2);
        end
        zupdateprob1 = bsxfun(@times,zupdateprob,1./(sum(zupdateprob,2)));
        % inverse-CDF draw of the class label of every unit
        mat1 = [zeros(N,1) cumsum(zupdateprob1,2)];
        rr = unifrnd(0,1,[N,1]);
        for l = 1:H
            ind = rr > mat1(:,l) & rr <= mat1(:,l+1); ph(ind) = l;
        end

        % 3)-- update Phi -- %
        for h = 1:H
            nh = reshape(ph==h,[N 1]);
            for c = 1:max(d)
                % prior count plus number of class-h units at level c, per variable
                mat(:,c) = (a.j(c) + sum(bsxfun(@times,(data==c),nh)))';
            end
            % Gamma draws normalised to sum to one over the levels of each variable
            Lamh1 = gamrnd(mat,1);
            for j=1:p
                Lamh = bsxfun(@times,Lamh1(j,1:d(j)), 1./sum(Lamh1(j,1:d(j))));
                phi(h,j,1:d(j)) = Lamh;
            end
        end

        % 4)-- update alpha -- %
        alpha = gamrnd(a0 + H-1, 1/(b0 - sum(log(1-nus(1:H-1)))));

        % 5) -- update beta_w -- %
        W_x = [ones(N,1) data];
        prec_w = eye(p+1)/Sigma0 + W_x'*W_x;   % built once, reused three times below
        VV = inv(prec_w);                      % explicit covariance is needed by mvnrnd
        SS = prec_w\W_x';
        HM = W_x * SS; hdiag = diag(HM); wgh = hdiag./(1-hdiag); zt_v = 1+wgh;

        zt = zeros(N,1);
        BB = prec_w\(W_x'*zt);

        % NOTE(review): the inner sweeps look like the auxiliary-variable
        % probit update with beta integrated out (leave-one-out mean and
        % variance), applied to all units at once — confirm against the
        % intended sampler.
        for t_inc=1:500
            % a) -- update zt -- %
            zt_old=zt;
            mi=W_x * BB;
            mi=mi-wgh .* (zt_old-mi);
            af=normcdf(zeros(N,1), mi, sqrt(zt_v));
            af(af==1 & W==1) = 0.99999;   % keep the truncated draw away from the boundary
            af(af==0 & W==0) = 0.00001;
            % uniform on (0,af) when W==0 and on (af,1) when W==1, then inverted
            ui=unifrnd(0,af).*(W==0)+unifrnd(af,1).*(W==1);
            zt=norminv(ui,mi,sqrt(zt_v));
            BB = BB + SS * (zt-zt_old);
        end

        % b) -- update beta_w -- %
        % Drawn once after the sweeps: the inner loop never reads beta_w, so
        % drawing it on every sweep only produced values that were discarded.
        beta_w = reshape(mvnrnd(BB,VV),[p+1,1]);

        % 6)-- impute missing values of Y1 -- %
        % Refreshment units: inverse-CDF draw from phi of each unit's class.
        y1_index = Np+1:N;
        for j=q+1:q+p1
            phi_y1 = reshape(phi(ph(Np+1:N),j,1:d(j)), [Nr d(j)]);
            mat_y1 = [zeros(Nr,1) cumsum(phi_y1,2)];
            r_y1 = unifrnd(0,1,[Nr,1]);
            for l=1:d(j)
                misind1 =r_y1 > mat_y1(:,l) & r_y1 <= mat_y1(:,l+1);
                data(y1_index(misind1),j)=l;
            end
        end

        % 7)-- impute missing values of Y2 -- %
        % Incomplete panel units: a candidate is drawn from phi of each
        % unit's class and replaces the current value only when a uniform
        % draw is at most 1 - normcdf(x*beta_w) evaluated at the candidate row.
        y2_index = Ncp+1:Np;
        for j=q+p1+1:p
            phi_y2 = reshape(phi(ph(Ncp+1:Np),j,1:d(j)), [Nip d(j)]);
            data_new=data;
            mat_y2 = [zeros(Nip,1) cumsum(phi_y2,2)];
            r_y2 = unifrnd(0,1,[Nip,1]);
            for l=1:d(j)
                misind2 =r_y2 > mat_y2(:,l) & r_y2 <= mat_y2(:,l+1);
                data_new(y2_index(misind2),j)=l;
            end
            jind=(rand(Nip,1) <= 1-normcdf([ones(Nip,1) data_new(Ncp+1:Np,:)]*beta_w));
            data(y2_index(jind),j)=data_new(y2_index(jind),j);
        end

        % 8)-- impute missing values of W -- %
        % Refreshment units: Bernoulli draw with probability normcdf(x*beta_w).
        W(Np+1:N)=(rand(Nr,1)<=normcdf([ones(Nr,1) data(Np+1:N,:)] * beta_w));

        % -- save sampled values (after thinning) -- %
        if mod(t,thin)==0 && t > burn
            s = (t-burn)/thin;   % row of the output arrays for this draw
            cell_out(s, :) = sum(bsxfun(@times,phi(:,:,1),nu'));
            alpha_out(s) = alpha;
            beta_w_out(s,:) = beta_w;
            pi_out(s,:) = nu;
            kout(s) = length(unique(ph));

            y1_out(s,:) = mean(data(:,1:p1)-1);
            y2_out(s,:) = mean(data(:,p1+1:p1+p2)-1);
            w_out(s) = mean(W);
        end

        %--store multiple completed data sets ---%
        if sum(cut==t)>0
            imp=imp+1;
            impdata_out(re,((imp-1)*N+1):imp*N,1) = repmat(imp, [N 1]);
            impdata_out(re,((imp-1)*N+1):imp*N,2:(p+1)) = data;
        end

        if mod(t,1000)==0
            t
        end

    end  %mcmc end

    %%%%----------------------MI inference-------------------------------%%%%
    % Point estimate and within-imputation variance for each completed data set.
    for l=1:m
        data_l=reshape(impdata_out(re,impdata_out(re,:,1)==l,2:(p+1))-1,[N, p]);
        q_prob(l,:) = mean(data_l==1);
        u_prob(l,:) = mean(data_l==1).* mean(data_l==0)/N;
        for j=1:p1
            q_prob_j(l,j)=mean(data_l(:,j)==1 & data_l(:,p1+j)==1);
            u_prob_j(l,j)=q_prob_j(l,j)*(1-q_prob_j(l,j))/N;
        end

        for j2=1:(p2-1)
            q_prob_j2(l,j2)=mean(data_l(:,j2+p1)==1 & data_l(:,p1+p2)==1);
            u_prob_j2(l,j2)=q_prob_j2(l,j2)*(1-q_prob_j2(l,j2))/N;
        end
    end

    % Combine across the m completed data sets: mean estimate, mean
    % within-imputation variance, between-imputation variance, total
    % variance, degrees of freedom and a 95% t interval.
    q_mi(re,:)=[mean(q_prob) mean(q_prob_j) mean(q_prob_j2)];
    u_mi(re,:)=[mean(u_prob) mean(u_prob_j) mean(u_prob_j2)];
    b_mi(re,:)=[var(q_prob) var(q_prob_j) var(q_prob_j2)];
    t_mi(re,:)=(1+1/m)*b_mi(re,:) + u_mi(re,:);
    df_mi=(m-1)*((1+u_mi(re,:)./b_mi(re,:)/(1+1/m)).^2);
    low_mi(re,:)=q_mi(re,:)-tinv(0.975,df_mi).* sqrt(t_mi(re,:));
    up_mi(re,:)=q_mi(re,:)+tinv(0.975,df_mi).* sqrt(t_mi(re,:));

    % 1 where the interval contains the true value, 0 otherwise
    cover(re,:) = (low_mi(re,:) <= true_q) & (true_q <= up_mi(re,:));

    if mod(re,10)==0
        re
    end
end


%%%%--------------summary over the repeated samples------------------%%%%
% Averages are taken explicitly over dimension 1 (replicates) so the result
% stays one value per quantity of interest even when Re == 1.
coverage = mean(cover,1)    % proportion of replicates whose interval covers the true value
avg_t_mi = mean(t_mi,1)     % average total MI variance

err  = bsxfun(@minus, q_mi, true_q);   % estimate minus true value, per replicate
bias = mean(err,1);
rmse = sqrt(sum(err.^2,1)/Re);

% Create the output folder if it is missing, so the results of a long run
% are not lost on a failed write.
if ~exist('output','dir')
    mkdir('output');
end
% NOTE(review): the file name spells "simumlation"; it is left unchanged in
% case other scripts read this exact path.
csvwrite('output/simumlation-an.data', [coverage' bias' rmse' avg_t_mi'])
toc   % elapsed time since the tic at the top of the script
